{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.timexer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "from nbdev.showdoc import show_doc\n",
    "from neuralforecast.common._model_checks import check_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TimeXer\n",
    "\n",
    "TimeXer empowers the canonical Transformer with the ability to reconcile endogenous and exogenous information, where patch-wise self-attention and variate-wise cross-attention are used simultaneously.\n",
    "\n",
    "**References**\n",
    "- [Yuxuan Wang, Haixu Wu, Jiaxiang Dong, Guo Qin, Haoran Zhang, Yong Liu, Yunzhong Qiu, Jianmin Wang, Mingsheng Long. \"TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables\"](https://arxiv.org/abs/2402.19072)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. Architecture of TimeXer.](imgs_models/timexer.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "from neuralforecast.common._base_model import BaseModel\n",
    "from neuralforecast.common._modules import (\n",
    "    DataEmbedding_inverted, \n",
    "    PositionalEmbedding,\n",
    "    FullAttention,\n",
    "    AttentionLayer\n",
    ")\n",
    "from typing import Optional"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Auxiliary functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class FlattenHead(nn.Module):\n",
    "    def __init__(self, n_vars, nf, target_window, head_dropout=0):\n",
    "        super().__init__()\n",
    "        self.n_vars = n_vars\n",
    "        self.flatten = nn.Flatten(start_dim=-2)\n",
    "        self.linear = nn.Linear(nf, target_window)\n",
    "        self.dropout = nn.Dropout(head_dropout)\n",
    "\n",
    "    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]\n",
    "        x = self.flatten(x)\n",
    "        x = self.linear(x)\n",
    "        x = self.dropout(x)\n",
    "        return x"
   ]
  },
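  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal shape sketch (illustrative only, not exported): `FlattenHead` flattens the last two dimensions of `[bs x n_vars x d_model x patch_num]` into `nf = d_model * patch_num` features per variate and projects them to the target window. All sizes below are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape sketch with arbitrary sizes; nf must equal d_model * patch_num of the input\n",
    "head = FlattenHead(n_vars=4, nf=64 * 3, target_window=12, head_dropout=0.0)\n",
    "x = torch.randn(2, 4, 64, 3)        # [bs x n_vars x d_model x patch_num]\n",
    "assert head(x).shape == (2, 4, 12)  # [bs x n_vars x target_window]"
   ]
  },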
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, layers, norm_layer=None, projection=None):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        self.norm = norm_layer\n",
    "        self.projection = projection\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)\n",
    "\n",
    "        if self.norm is not None:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        if self.projection is not None:\n",
    "            x = self.projection(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class EncoderLayer(nn.Module):\n",
    "    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,\n",
    "                 dropout=0.1, activation=\"relu\"):\n",
    "        super(EncoderLayer, self).__init__()\n",
    "        d_ff = d_ff or 4 * d_model\n",
    "        self.self_attention = self_attention\n",
    "        self.cross_attention = cross_attention\n",
    "        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)\n",
    "        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)\n",
    "        self.norm1 = nn.LayerNorm(d_model)\n",
    "        self.norm2 = nn.LayerNorm(d_model)\n",
    "        self.norm3 = nn.LayerNorm(d_model)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.activation = F.relu if activation == \"relu\" else F.gelu\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):\n",
    "        B, L, D = cross.shape\n",
    "        x = x + self.dropout(self.self_attention(\n",
    "            x, x, x,\n",
    "            attn_mask=x_mask,\n",
    "            tau=tau, delta=None\n",
    "        )[0])\n",
    "        x = self.norm1(x)\n",
    "\n",
    "        x_glb_ori = x[:, -1, :].unsqueeze(1)\n",
    "        x_glb = torch.reshape(x_glb_ori, (B, -1, D))\n",
    "        x_glb_attn = self.dropout(self.cross_attention(\n",
    "            x_glb, cross, cross,\n",
    "            attn_mask=cross_mask,\n",
    "            tau=tau, delta=delta\n",
    "        )[0])\n",
    "        x_glb_attn = torch.reshape(x_glb_attn,\n",
    "                                   (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2])).unsqueeze(1)\n",
    "        x_glb = x_glb_ori + x_glb_attn\n",
    "        x_glb = self.norm2(x_glb)\n",
    "\n",
    "        y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)\n",
    "\n",
    "        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n",
    "        y = self.dropout(self.conv2(y).transpose(-1, 1))\n",
    "\n",
    "        return self.norm3(x + y)"
   ]
  },
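  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal shape sketch of one encoder layer (illustrative only, not exported). Note the asymmetry: the endogenous tokens are batched as `bs * n_vars`, while the exogenous (variate-level) tokens are batched as `bs`; the layer reshapes the global token accordingly before cross-attention. All sizes below are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape sketch with arbitrary sizes, mirroring how TimeXer builds its encoder\n",
    "def _make_attention(d_model=64, n_heads=4):\n",
    "    return AttentionLayer(\n",
    "        FullAttention(False, 1, attention_dropout=0.0, output_attention=False),\n",
    "        d_model, n_heads)\n",
    "\n",
    "layer = EncoderLayer(_make_attention(), _make_attention(), d_model=64, d_ff=128, dropout=0.0)\n",
    "x = torch.randn(2 * 4, 3, 64)  # endogenous: [bs * n_vars, patch_num + 1, d_model]\n",
    "cross = torch.randn(2, 5, 64)  # exogenous: [bs, n_exog_tokens, d_model]\n",
    "assert layer(x, cross).shape == x.shape"
   ]
  },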
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class EnEmbedding(nn.Module):\n",
    "    def __init__(self, n_vars, d_model, patch_len, dropout):\n",
    "        super(EnEmbedding, self).__init__()\n",
    "        # Patching\n",
    "        self.patch_len = patch_len\n",
    "\n",
    "        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)\n",
    "        self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))\n",
    "        self.position_embedding = PositionalEmbedding(d_model)\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # do patching\n",
    "        n_vars = x.shape[1]\n",
    "        glb = self.glb_token.repeat((x.shape[0], 1, 1, 1))\n",
    "\n",
    "        x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)\n",
    "        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))\n",
    "        # Input encoding\n",
    "        x = self.value_embedding(x) + self.position_embedding(x)\n",
    "        x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))\n",
    "        x = torch.cat([x, glb], dim=2)\n",
    "        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))\n",
    "        return self.dropout(x), n_vars"
   ]
  },
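  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal shape sketch (illustrative only, not exported): `EnEmbedding` splits each endogenous series into non-overlapping patches of length `patch_len`, embeds each patch, and appends a learnable global token per variate. All sizes below are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape sketch with arbitrary sizes: input_size=32, patch_len=16 -> 2 patches + 1 global token\n",
    "emb = EnEmbedding(n_vars=4, d_model=64, patch_len=16, dropout=0.0)\n",
    "tokens, n_vars = emb(torch.randn(2, 4, 32))  # [bs x n_vars x input_size]\n",
    "assert tokens.shape == (2 * 4, 2 + 1, 64)    # [bs * n_vars, patch_num + 1, d_model]\n",
    "assert n_vars == 4"
   ]
  },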
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "class TimeXer(BaseModel):\n",
    "    \"\"\"\n",
    "    TimeXer\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `n_series`: int, number of time-series.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns.<br>\n",
    "    `patch_len`: int, length of patches.<br>\n",
    "    `hidden_size`: int, dimension of the model.<br>\n",
    "    `n_heads`: int, number of heads.<br>\n",
    "    `e_layers`: int, number of encoder layers.<br>\n",
    "    `d_ff`: int, dimension of fully-connected layer.<br>\n",
    "    `factor`: int, attention factor.<br>\n",
    "    `dropout`: float, dropout rate.<br>\n",
    "    `use_norm`: bool, whether to normalize or not.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=1000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=32, number of windows in each batch.<br>    \n",
    "    `inference_windows_batch_size`: int=32, number of windows to sample in each inference batch, -1 uses all.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
    "    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
    "    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
    "    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
    "    `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "\n",
    "    **Parameters:**<br>\n",
    "\n",
    "    **References**\n",
    "    - [Yuxuan Wang, Haixu Wu, Jiaxiang Dong, Guo Qin, Haoran Zhang, Yong Liu, Yunzhong Qiu, Jianmin Wang, Mingsheng Long. \"TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables\"](https://arxiv.org/abs/2402.19072)\n",
    "    \"\"\"\n",
    "\n",
    "    # Class attributes\n",
    "    EXOGENOUS_FUTR = True\n",
    "    EXOGENOUS_HIST = False\n",
    "    EXOGENOUS_STAT = False\n",
    "    MULTIVARIATE = True    # If the model produces multivariate forecasts (True) or univariate (False)\n",
    "    RECURRENT = False       # If the model produces forecasts recursively (True) or direct (False)\n",
    "\n",
    "    def __init__(self,\n",
    "                 h,\n",
    "                 input_size,\n",
    "                 n_series,\n",
    "                 futr_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 stat_exog_list = None,\n",
    "                 exclude_insample_y: bool = False,\n",
    "                 patch_len: int = 16,\n",
    "                 hidden_size: int = 512,\n",
    "                 n_heads: int = 8,\n",
    "                 e_layers: int = 2,\n",
    "                 d_ff: int = 2048,\n",
    "                 factor: int = 1,\n",
    "                 dropout: float = 0.1,\n",
    "                 use_norm: bool = True,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 1000,\n",
    "                 learning_rate: float = 1e-3,\n",
    "                 num_lr_decays: int = -1,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size = 32,\n",
    "                 inference_windows_batch_size = 32,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'identity',\n",
    "                 random_seed: int = 1,\n",
    "                 drop_last_loader: bool = False,\n",
    "                 alias: Optional[str] = None,\n",
    "                 optimizer = None,\n",
    "                 optimizer_kwargs = None,\n",
    "                 lr_scheduler = None,\n",
    "                 lr_scheduler_kwargs = None,\n",
    "                 dataloader_kwargs = None,\n",
    "                 **trainer_kwargs):\n",
    "        \n",
    "        super(TimeXer, self).__init__(h=h,\n",
    "                                    input_size=input_size,\n",
    "                                    n_series=n_series,\n",
    "                                    futr_exog_list=futr_exog_list,\n",
    "                                    hist_exog_list=hist_exog_list,\n",
    "                                    stat_exog_list=stat_exog_list,\n",
    "                                    exclude_insample_y=exclude_insample_y,\n",
    "                                    loss=loss,\n",
    "                                    valid_loss=valid_loss,\n",
    "                                    max_steps=max_steps,\n",
    "                                    learning_rate=learning_rate,\n",
    "                                    num_lr_decays=num_lr_decays,\n",
    "                                    early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                    val_check_steps=val_check_steps,\n",
    "                                    batch_size=batch_size,\n",
    "                                    valid_batch_size=valid_batch_size,\n",
    "                                    windows_batch_size=windows_batch_size,\n",
    "                                    inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                    start_padding_enabled=start_padding_enabled,\n",
    "                                    step_size=step_size,\n",
    "                                    scaler_type=scaler_type,\n",
    "                                    random_seed=random_seed,\n",
    "                                    drop_last_loader=drop_last_loader,\n",
    "                                    alias=alias,\n",
    "                                    optimizer=optimizer,\n",
    "                                    optimizer_kwargs=optimizer_kwargs,\n",
    "                                    lr_scheduler=lr_scheduler,\n",
    "                                    lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
    "                                    dataloader_kwargs=dataloader_kwargs,\n",
    "                                    **trainer_kwargs)\n",
    "        \n",
    "        self.enc_in = n_series\n",
    "        self.hidden_size = hidden_size\n",
    "        self.n_heads = n_heads\n",
    "        self.e_layers = e_layers\n",
    "        self.d_ff = d_ff\n",
    "        self.dropout = dropout\n",
    "        self.factor = factor\n",
    "        self.patch_len = patch_len\n",
    "        self.use_norm = use_norm\n",
    "        self.patch_num = int(input_size // self.patch_len)\n",
    "\n",
    "        # Architecture\n",
    "        self.en_embedding = EnEmbedding(n_series, self.hidden_size, self.patch_len, self.dropout)\n",
    "        self.ex_embedding = DataEmbedding_inverted(input_size, self.hidden_size, self.dropout)\n",
    "\n",
    "        self.encoder = Encoder(\n",
    "            [\n",
    "                EncoderLayer(\n",
    "                    AttentionLayer(\n",
    "                        FullAttention(False, self.factor, attention_dropout=self.dropout,\n",
    "                                      output_attention=False),\n",
    "                        self.hidden_size, self.n_heads),\n",
    "                    AttentionLayer(\n",
    "                        FullAttention(False, self.factor, attention_dropout=self.dropout,\n",
    "                                      output_attention=False),\n",
    "                        self.hidden_size, self.n_heads),\n",
    "                    self.hidden_size,\n",
    "                    self.d_ff,\n",
    "                    dropout=self.dropout,\n",
    "                    activation='relu',\n",
    "                )\n",
    "                for l in range(self.e_layers)\n",
    "            ],\n",
    "            norm_layer=torch.nn.LayerNorm(self.hidden_size)\n",
    "        )\n",
    "        self.head_nf = self.hidden_size * (self.patch_num + 1)\n",
    "        self.head = FlattenHead(self.enc_in, self.head_nf, h * self.loss.outputsize_multiplier,\n",
    "                                head_dropout=self.dropout)\n",
    "        \n",
    "    def forecast(self, x_enc, x_mark_enc):\n",
    "        if self.use_norm:\n",
    "            # Normalization from Non-stationary Transformer\n",
    "            means = x_enc.mean(1, keepdim=True).detach()\n",
    "            x_enc = x_enc - means\n",
    "            stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)\n",
    "            x_enc /= stdev\n",
    "\n",
    "        _, _, N = x_enc.shape\n",
    "\n",
    "        en_embed, n_vars = self.en_embedding(x_enc.permute(0, 2, 1))\n",
    "        ex_embed = self.ex_embedding(x_enc, x_mark_enc)\n",
    "\n",
    "        enc_out = self.encoder(en_embed, ex_embed)\n",
    "        enc_out = torch.reshape(\n",
    "            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))\n",
    "        # z: [bs x nvars x d_model x patch_num]\n",
    "        enc_out = enc_out.permute(0, 1, 3, 2)\n",
    "\n",
    "        dec_out = self.head(enc_out)  # z: [bs x nvars x h * n_outputs]\n",
    "        dec_out = dec_out.permute(0, 2, 1)\n",
    "\n",
    "        if self.use_norm:\n",
    "            # De-Normalization from Non-stationary Transformer\n",
    "            dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.h * self.loss.outputsize_multiplier, 1))\n",
    "            dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.h * self.loss.outputsize_multiplier, 1))\n",
    "\n",
    "        return dec_out\n",
    "    \n",
    "    def forward(self, windows_batch):\n",
    "        insample_y = windows_batch['insample_y']\n",
    "        futr_exog = windows_batch['futr_exog']\n",
    "        \n",
    "        if self.futr_exog_size > 0:\n",
    "            x_mark_enc = futr_exog[:, :, :self.input_size, :]\n",
    "            B, V, T, D = x_mark_enc.shape\n",
    "            x_mark_enc = x_mark_enc.reshape(B, T, V*D)\n",
    "        else:\n",
    "            x_mark_enc = None\n",
    "\n",
    "        y_pred = self.forecast(insample_y, x_mark_enc)\n",
    "        y_pred = y_pred.reshape(insample_y.shape[0],\n",
    "                                self.h,\n",
    "                                -1)\n",
    "        return y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimeXer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimeXer.fit, name='TimeXer.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimeXer.predict, name='TimeXer.predict')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unit tests for models\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "logging.getLogger(\"lightning_fabric\").setLevel(logging.ERROR)\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    check_model(TimeXer, [\"airpassengers\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Usage example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import TimeXer\n",
    "from neuralforecast.losses.pytorch import MSE\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
    "\n",
    "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
    "\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "model = TimeXer(h=12,\n",
    "                input_size=24,\n",
    "                n_series=2,\n",
    "                futr_exog_list=[\"trend\", \"month\"],\n",
    "                patch_len=12,\n",
    "                hidden_size=128,\n",
    "                n_heads=16,\n",
    "                e_layers=2,\n",
    "                d_ff=256,\n",
    "                factor=1,\n",
    "                dropout=0.1,\n",
    "                use_norm=True,\n",
    "                loss=MSE(),\n",
    "                valid_loss=MAE(),\n",
    "                early_stop_patience_steps=3,\n",
    "                batch_size=32)\n",
    "\n",
    "fcst = NeuralForecast(models=[model], freq='ME')\n",
    "fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "forecasts = fcst.predict(futr_df=Y_test_df)\n",
    "\n",
    "# Plot predictions\n",
    "fig, ax = plt.subplots(1, 1, figsize = (20, 7))\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "plt.plot(plot_df['ds'], plot_df['TimeXer'], c='blue', label='Forecast')\n",
    "ax.set_title('AirPassengers Forecast', fontsize=22)\n",
    "ax.set_ylabel('Monthly Passengers', fontsize=20)\n",
    "ax.set_xlabel('Year', fontsize=20)\n",
    "ax.legend(prop={'size': 15})\n",
    "ax.grid()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
